/*
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License version 2
 *   as published by the Free Software Foundation.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *   Copyright (C) 2008  Benjamin Segovia <segovia.benjamin@gmail.com>
 */
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>

#ifdef _WIN32
    #define WINDOWS_LEAN_AND_MEAN
    #include <windows.h>
#endif
#include <cuda_gl_interop.h>
#include <cutil.h>

/* C-linkage entry points implemented in this file (presumably called from the
 * host application, which is not visible here). */

/* Device initialization (wraps cutil's CUT_DEVICE_INIT) */
extern "C" void init_cuda(int argc, char **argv);
/* Render one w x h frame into the registered PBO pbo_out. cam and aabb are
 * copied byte-for-byte into cu_cam / cu_aabb, pos (3 floats) into cu_lpos.
 * Variants: diffuse + shadow rays, diffuse only, and depth coloring only. */
extern "C" void cuda_shadowed(int pbo_out, void *cam, void *aabb, float *pos, int w, int h);
extern "C" void cuda_unshadowed(int pbo_out, void *cam, void *aabb, float *pos, int w, int h);
extern "C" void cuda_noshading(int pbo_out, void *cam, void *aabb, float *pos, int w, int h);
/* Register / unregister an OpenGL buffer object with CUDA */
extern "C" void pbo_register(int pbo);
extern "C" void pbo_unregister(int pbo);
/* Bind device buffers (size bytes at data) to node_tex, tri_tex, id_tex, n_tex */
extern "C" void node_texture_bind(const void *data, size_t size);
extern "C" void rt_tri_texture_bind(const void *data, size_t size);
extern "C" void id_texture_bind(const void *data, size_t size);
extern "C" void normal_texture_bind(const void *data, size_t size);

/* Quick defines and typedef for the cuda part of the computations */
#define VEC_ZERO vec_t(0.f, 0.f, 0.f)
/* "No hit yet" ray distance. <math.h> may already define HUGE (some C
 * libraries do), so drop that definition instead of redefining over it. */
#ifdef HUGE
    #undef HUGE
#endif
#define HUGE 1e10f
#ifndef M_PI
    #define M_PI    3.14159265358979323846
#endif
/* Degree <-> radian conversions in single precision */
#define DEGRAD(a)    ((a) * (float(M_PI) / 180.f))
#define RADDEG(a)    ((a) * (180.f/ float(M_PI)))
typedef unsigned int uint32_t;
/* Thread block shape used by the host launch wrappers */
#define THREAD_W 2
#define THREAD_H 32
/* Parenthesized so the macro composes correctly inside larger expressions */
#define THREAD_N (THREAD_W * THREAD_H)

/* Texture for kd-tree nodes: one uint2 per node, fetched into
 * kdtree::node_t::internal (node 0 is the root) */
texture<uint2, 1, cudaReadModeElementType> node_tex;

/* Texture for triangles: three uint4 records per triangle, read into
 * wald_tri_t::internal0..2 */
texture<uint4, 1, cudaReadModeElementType> tri_tex;

/* Texture for ids: triangle indices referenced by the kd-tree leaves, stored
 * contiguously per leaf */
texture<unsigned int, 1, cudaReadModeElementType> id_tex;

/* Texture for normals: three floats (x, y, z) per triangle id */
texture<float, 1, cudaReadModeElementType> n_tex;

/* The position of the light source, uploaded from three host floats */
__constant__ float3 cu_lpos;

/***************************************************************************//**
 * Initialize the node texture
 *
 * Binds the device buffer [data, data + size) holding the kd-tree nodes to
 * node_tex. Nodes are read with tex1Dfetch and integer indices, so the
 * texture uses point filtering and unnormalized coordinates.
/******************************************************************************/
void node_texture_bind(const void *data, size_t size)
{
    /* Sampling state first, then wrap addressing on both axes */
    node_tex.normalized = false;
    node_tex.filterMode = cudaFilterModePoint;
    for (int axis = 0; axis < 2; ++axis)
        node_tex.addressMode[axis] = cudaAddressModeWrap;
    CUDA_SAFE_CALL(cudaBindTexture(0, node_tex, data, size));
}

/***************************************************************************//**
 * Initialize the triangle texture
 *
 * Binds the device buffer [data, data + size) holding the precomputed
 * triangle records (three uint4 per triangle) to tri_tex. The sampling state
 * set here belongs to tri_tex itself: point filtering and unnormalized
 * coordinates, as used by tex1Dfetch lookups.
/******************************************************************************/
void rt_tri_texture_bind(const void *data, size_t size)
{
    tri_tex.addressMode[0] = cudaAddressModeWrap;
    tri_tex.addressMode[1] = cudaAddressModeWrap;
    tri_tex.filterMode = cudaFilterModePoint;
    tri_tex.normalized = false;
    CUDA_SAFE_CALL(cudaBindTexture(0, tri_tex, data, size));
}

/***************************************************************************//**
 * Initialize the id texture
 *
 * Binds the device buffer [data, data + size) of triangle indices referenced
 * by the kd-tree leaves to id_tex. The sampling state set here belongs to
 * id_tex itself: point filtering and unnormalized coordinates, as used by
 * tex1Dfetch lookups.
/******************************************************************************/
void id_texture_bind(const void *data, size_t size)
{
    id_tex.addressMode[0] = cudaAddressModeWrap;
    id_tex.addressMode[1] = cudaAddressModeWrap;
    id_tex.filterMode = cudaFilterModePoint;
    id_tex.normalized = false;
    CUDA_SAFE_CALL(cudaBindTexture(0, id_tex, data, size));
}

/***************************************************************************//**
 * Initialize the normal texture
 *
 * Binds the device buffer [data, data + size) of per-triangle normals (three
 * floats per triangle id) to n_tex. Point filtering and unnormalized
 * coordinates are used for tex1Dfetch lookups.
/******************************************************************************/
void normal_texture_bind(const void *data, size_t size)
{
    n_tex.normalized = false;
    n_tex.filterMode = cudaFilterModePoint;
    for (int axis = 0; axis < 2; ++axis)
        n_tex.addressMode[axis] = cudaAddressModeWrap;
    CUDA_SAFE_CALL(cudaBindTexture(0, n_tex, data, size));
}

/***************************************************************************//**
 * Float 3D vector
 *
 * Small device-side vector with component-wise arithmetic, dot/cross products
 * and cyclic component permutations. The default constructor leaves x, y, z
 * uninitialized. It is kept empty because vec_t is embedded in the
 * __constant__ camera cu_cam, where CUDA only allows empty constructors.
/******************************************************************************/
struct vec_t {
    float x, y, z;

    /* Construction: uninitialized, per-component, or one scalar splatted */
    __device__ vec_t() {}
    __device__ vec_t(const float a, const float b, const float c) : x(a), y(b), z(c) {}
    __device__ vec_t(const float a) : x(a), y(a), z(a) {}

    /* Component-wise arithmetic */
    __device__ vec_t operator+(const vec_t &v) const {
        return vec_t(x + v.x, y + v.y, z + v.z);
    }
    __device__ vec_t operator-(const vec_t &v) const {
        return vec_t(x - v.x, y - v.y, z - v.z);
    }
    __device__ vec_t operator-() const {
        return vec_t(-x, -y, -z);
    }
    __device__ vec_t operator*(const float d) const {
        return vec_t(x * d, y * d, z * d);
    }

    /* Products and lengths */
    __device__ float dot(const vec_t &v) const {
        return x * v.x + y * v.y + z * v.z;
    }
    __device__ vec_t cross(const vec_t &v) const {
        const float cx = y * v.z - z * v.y;
        const float cy = z * v.x - x * v.z;
        const float cz = x * v.y - y * v.x;
        return vec_t(cx, cy, cz);
    }
    __device__ float magsqr() const { return dot(*this); }
    __device__ float norm() const { return sqrtf(magsqr()); }
    /* Unit-length copy; a zero vector yields inf/NaN components */
    __device__ vec_t normalize() const {
        const float inv_len = 1.f / norm();
        return *this * inv_len;
    }

    /* Smallest / largest component */
    __device__ float get_min() const { return fminf(fminf(x, y), z); }
    __device__ float get_max() const { return fmaxf(fmaxf(x, y), z); }

    /* Cyclic rotations that move the named component to the front */
    __device__ vec_t perm_x() const { return vec_t(x, y, z); }
    __device__ vec_t perm_y() const { return vec_t(y, z, x); }
    __device__ vec_t perm_z() const { return vec_t(z, x, y); }
};

/***************************************************************************//**
 * Integer 2D point
 *
 * An int pair with component-wise arithmetic against another point or a
 * scalar. The coordinates are plain members. Wrapping them in an anonymous
 * struct would be a non-standard C++ extension with no effect on layout.
/******************************************************************************/
struct point_t {
    int x, y;
    __device__ point_t() {}
    __device__ point_t(const int a, const int b) : x(a), y(b) {}
    /* Product of the two coordinates (area of an x by y extent) */
    __device__ int area() const { return x*y; }
    __device__ point_t operator+(const point_t rhs) const { return point_t(x+rhs.x, y+rhs.y); }
    __device__ point_t operator-(const point_t rhs) const { return point_t(x-rhs.x, y-rhs.y); }
    __device__ point_t operator+(const int rhs) const { return point_t(x+rhs, y+rhs); }
    __device__ point_t operator-(const int rhs) const { return point_t(x-rhs, y-rhs); }
    __device__ point_t operator*(const int rhs) const { return point_t(x*rhs, y*rhs); }
    __device__ point_t operator/(const int rhs) const { return point_t(x/rhs, y/rhs); }
};

/***************************************************************************//**
 * Clamp x on [a, b]
 *
 * The upper bound is applied first, then the lower bound. Under fmin/fmax NaN
 * rules, a NaN x therefore ends up as b (for a <= b).
/******************************************************************************/
__device__ float clamp(float x, float a, float b)
{
    const float capped = fminf(b, x);
    return fmaxf(a, capped);
}

/***************************************************************************//**
 * Convert floating point rgb color to 8-bit integer
 *
 * Each component of v (x = red, y = green, z = blue) is clamped to [0, 255]
 * and truncated. The result is packed with red in bits 0-7, green in bits
 * 8-15 and blue in bits 16-23; the top byte is zero.
/******************************************************************************/
__device__ static inline int frgb_to_int(vec_t v)
{
    const int red   = int(clamp(v.x, 0.0f, 255.0f));
    const int green = int(clamp(v.y, 0.0f, 255.0f));
    const int blue  = int(clamp(v.z, 0.0f, 255.0f));
    return red | (green << 8) | (blue << 16);
}

/***************************************************************************//**
 * The ray structure
 *
 * Origin `pos` and direction `dir`. get_dir_pos(i) returns the components
 * along axis i as (origin, direction) in (.x, .y). Any i other than 0 or 1
 * selects the z axis.
/******************************************************************************/
struct ray_t {
    vec_t pos, dir;
    __device__ inline void set_pos(const vec_t &o) { pos = o; }
    __device__ inline void set_dir(const vec_t &d) { dir = d; }
    __device__ inline float2 get_dir_pos(uint32_t i) const {
        float2 axis_comp;
        if (i == 0) {
            axis_comp.x = pos.x;
            axis_comp.y = dir.x;
        } else if (i == 1) {
            axis_comp.x = pos.y;
            axis_comp.y = dir.y;
        } else {
            axis_comp.x = pos.z;
            axis_comp.y = dir.z;
        }
        return axis_comp;
    }
};

/***************************************************************************//**
 * The triangle used for the intersection computations
 *
 * Three packed uint4 records, read from tri_tex at 3*i .. 3*i+2:
 *   internal0 = (k, n_u, n_v, n_d)
 *   internal1 = (vert_ku, vert_kv, b_nu, b_nv)
 *   internal2 = (c_nu, c_nv, id, matid)
 * k, id and matid are integers. Every other field is a float stored by bit
 * pattern and decoded with int_as_float. This presumably follows Wald's
 * projected-triangle layout, precomputed on the host (not visible here).
/******************************************************************************/
struct wald_tri_t {
    uint4 internal0, internal1, internal2;

    /* internal0 */
    __device__ uint32_t k() const { return internal0.x; }
    __device__ float n_u() const { return int_as_float(internal0.y); }
    __device__ float n_v() const { return int_as_float(internal0.z); }
    __device__ float n_d() const { return int_as_float(internal0.w); }

    /* internal1 */
    __device__ float vert_ku() const { return int_as_float(internal1.x); }
    __device__ float vert_kv() const { return int_as_float(internal1.y); }
    __device__ float b_nu() const { return int_as_float(internal1.z); }
    __device__ float b_nv() const { return int_as_float(internal1.w); }

    /* internal2 */
    __device__ float c_nu() const { return int_as_float(internal2.x); }
    __device__ float c_nv() const { return int_as_float(internal2.y); }
    __device__ uint32_t id() const { return internal2.z; }
    __device__ uint32_t matid() const { return internal2.w; }

    /* Ray origin/direction with their components cyclically rotated so that
     * axis k() comes first (k == 0: x, k == 1: y, anything else: z). */
    struct perm_t { vec_t dir, pos; };
    __device__ inline perm_t get_perm(const ray_t &ray) const {
        perm_t rotated;
        const uint32_t first_axis = k();
        if (first_axis == 0) {
            rotated.dir = ray.dir.perm_x();
            rotated.pos = ray.pos.perm_x();
        } else if (first_axis == 1) {
            rotated.dir = ray.dir.perm_y();
            rotated.pos = ray.pos.perm_y();
        } else {
            rotated.dir = ray.dir.perm_z();
            rotated.pos = ray.pos.perm_z();
        }
        return rotated;
    }
};

/***************************************************************************//**
 * The kd-tree structure
 *
 * A node is one uint2 (internal):
 *  - leaf:  bit 31 of internal.x set; bits 0..30 (mask_list) are the start
 *           index of its triangle ids in id_tex; internal.y is the count.
 *  - inner: bits 0..1 of internal.x (mask_axis) hold the split axis and the
 *           remaining bits (mask_children) encode the child offset;
 *           internal.y holds the split coordinate as float bits.
 * mask_leaf is derived from mask_list so that no constant expression shifts
 * a 1 into the sign bit of an int. "(int) 1ul<<31" casts before shifting,
 * which is undefined behavior before C++14.
/******************************************************************************/
namespace kdtree {
    struct node_t {
        enum {
            mask_list = 0x7fffffff, mask_leaf = ~mask_list,
            mask_axis = 3, mask_children = ~mask_axis
        };
        uint2 internal;
        __device__ inline uint32_t is_leaf() const { return offset_flag() & (uint32_t) mask_leaf; }
        __device__ inline uint32_t is_node() const { return !is_leaf(); }
        __device__ inline uint32_t get_list() const { return offset_flag() & (uint32_t) mask_list; }
        __device__ inline uint32_t get_axis() const { return dim_offset_flag() & (uint32_t) mask_axis; }
        __device__ inline uint32_t get_offset() const { return dim_offset_flag() & (uint32_t) mask_children; }
        __device__ inline uint32_t offset_flag() const {return internal.x;}
        __device__ inline uint32_t tri_count() const {return internal.y;}
        __device__ inline uint32_t dim_offset_flag() const {return internal.x;}
        __device__ inline float split_coord() const {return int_as_float(internal.y);}
    };
    typedef uint32_t tri_id_t;
}

/***************************************************************************//**
 * The structure used to store the intersection point
 *
 * t is the ray parameter of the closest hit so far. (u, v) are the
 * barycentric coordinates (beta, gamma) written by intersect_ray_tri. id is
 * the hit triangle's id, or ~0u when nothing was hit. u and v stay
 * uninitialized until a hit is recorded.
 * The default constructor is __host__ __device__ because hit_t locals are
 * declared inside kernels, where a host-only constructor cannot be called.
/******************************************************************************/
struct hit_t {
    float t, u, v;
    uint32_t id;
    __host__ __device__ hit_t() : t(HUGE), id(~0u) {}
    __device__ explicit hit_t(const float tmax) : t(tmax), id((uint32_t)(-1)) {}
};

/***************************************************************************//**
 * The trace stack used when traversing the kd-tree
 *
 * A fixed-capacity stack of deferred subtrees: a node index plus the ray
 * interval [tnear, tfar] to resume with. The traversal code pushes by filling
 * get() and incrementing idx. There is no overflow check, so the tree must
 * never need more than `capacity` pending entries.
 * The constructor is __host__ __device__ because the stack is a local in
 * device functions, where a host-only constructor cannot be called.
/******************************************************************************/
struct trace_stack_t {
    enum { capacity = 32 };
    struct trace_t {
        uint32_t node;
        float tnear;
        float tfar;
    };
    trace_t t[capacity];
    int idx;
    /* Entry at the current top index */
    __device__ inline trace_t& get() { return t[idx]; }
    __host__ __device__ trace_stack_t() : idx(0) {}
    __device__ inline void reset() { idx = 0; }
    /* Decrements idx; nonzero while an entry remains (get() then returns it) */
    __device__ inline int pop() { return --idx >= 0; }
    __device__ inline uint32_t& node() { return get().node; }
    __device__ inline float& t_near() { return get().tnear; }
    __device__ inline float& t_far() { return get().tfar; }
};

/***************************************************************************//**
 * Simple structure to sample the screen
 *
 * Maps integer pixel coordinates (x, y) to top + dx * x + dy * y, an
 * unnormalized primary-ray direction.
/******************************************************************************/
struct sampler_t {
    vec_t top, dx, dy;
    __device__ vec_t map(const point_t &screen) const {
        const vec_t step_x = dx * float(screen.x);
        const vec_t step_y = dy * float(screen.y);
        return top + step_x + step_y;
    }
};

/***************************************************************************//**
 * Simple camera structure
 *
 * eye, dir, up and right describe the view frame. fovx stores half of the
 * horizontal field of view in radians; set_fovx/get_fovx convert from and to
 * the full angle in degrees. sampler maps pixel coordinates to primary-ray
 * directions.
 *
 * The host wrappers copy sizeof(camera_t) raw bytes from a host object into
 * cu_cam, so member order and layout must match the host-side camera. That is
 * presumably an identical declaration; verify before changing members.
 * Members are declared directly, not inside an anonymous struct (a
 * non-standard C++ extension), which leaves the byte layout identical.
/******************************************************************************/
struct camera_t {
    vec_t eye, dir, up, right;
    float fovx;
    int world_up_index;
    sampler_t sampler;
    /* Declared only; no definition is visible in this file */
    __device__ inline void look_at(const vec_t &target, const int up_idx = -1);
    __device__ inline void set_fovx(const float degree) { fovx = DEGRAD(degree) * .5f; }
    __device__ inline void set_eye(const vec_t &v) { eye = v; }
    __device__ inline float get_fovx() const { return RADDEG(fovx)*2.f; }
    __device__ inline const vec_t &get_eye() const { return eye; }
    __device__ inline const vec_t &get_up() const { return up; }
    __device__ inline const vec_t &get_dir() const { return dir; }
    __device__ inline const vec_t &get_right() const { return right; }
    __device__ inline void set_world_up_index(int idx) { world_up_index = idx; }
    __device__ inline int get_world_up_index() const { return world_up_index; }
};

/* The camera used to generate the primary rays */
__constant__ camera_t cu_cam;

/***************************************************************************//**
 * The axis-aligned bounding box
 *
 * Plain data copied byte-for-byte from the host into cu_aabb. The member
 * order (all minima, then all maxima) must therefore stay fixed.
/******************************************************************************/
struct aabb_t {
    /* Lower corner */
    float xmin;
    float ymin;
    float zmin;
    /* Upper corner */
    float xmax;
    float ymax;
    float zmax;
};

/* The bounding box of the scene */
__constant__ aabb_t cu_aabb;

/* Narrows [tmin, tmax] to the ray-parameter range between the two planes lo
 * and hi of one slab, given the ray origin o and direction d along that axis. */
__device__ static inline void
box_clip_slab(float lo, float hi, float o, float d, float &tmin, float &tmax)
{
    const float t_lo = __fdividef(lo - o, d);
    const float t_hi = __fdividef(hi - o, d);
    tmin = fmaxf(fminf(t_lo, t_hi), tmin);
    tmax = fminf(fmaxf(t_lo, t_hi), tmax);
}

/***************************************************************************//**
 * Perform the Kay-Kajiya ray-box intersection
 *
 * Clips the incoming interval [tmin, tmax] in place against the x, y and z
 * slabs of `box`. Returns nonzero when the clipped interval is non-empty and
 * does not lie entirely behind the origin (tmax >= 0).
/******************************************************************************/
__device__ static inline uint32_t
intersect_ray_box(const aabb_t &box, const ray_t &ray, float &tmin, float &tmax)
{
    box_clip_slab(box.xmin, box.xmax, ray.pos.x, ray.dir.x, tmin, tmax);
    box_clip_slab(box.ymin, box.ymax, ray.pos.y, ray.dir.y, tmin, tmax);
    box_clip_slab(box.zmin, box.zmax, ray.pos.z, ray.dir.z, tmin, tmax);
    const uint32_t non_empty = tmax >= tmin;
    const uint32_t ahead = tmax >= 0.f;
    return non_empty & ahead;
}

/***************************************************************************//**
 * Perform the intersection between one ray and one triangle
 *
 * Closest-hit test of triangle `id` (records 3*id .. 3*id+2 of tri_tex). The
 * hit is recorded in `hit` (t, u = beta, v = gamma, id) only when the plane
 * distance t lies in [t_near, t_far], is strictly closer than hit.t, and
 * (beta, gamma) lies inside the triangle.
 *
 * Acceptance is written as a negated "inside" test. Every comparison with NaN
 * is false, so a NaN t (0/0 when the ray lies in the triangle's plane) or a
 * NaN beta/gamma is rejected instead of being stored as a hit.
/******************************************************************************/
static __device__ inline void
intersect_ray_tri(
    uint32_t id,
    const ray_t &ray, hit_t &hit,
    float t_near, float t_far)
{
    wald_tri_t tri;
    /* The first record alone is enough for the plane distance */
    tri.internal0 = tex1Dfetch(tri_tex, 3 * id);
    wald_tri_t::perm_t p = tri.get_perm(ray);
    const float dot = (tri.n_d() - p.pos.x - tri.n_u()*p.pos.y - tri.n_v()*p.pos.z);
    const float denum = (p.dir.x + tri.n_u()*p.dir.y + tri.n_v()*p.dir.z);
    const float t = __fdividef(dot, denum);
    if (!((t < hit.t) & (t >= t_near) & (t <= t_far))) return;

    /* Fetch the remaining records only once the distance test has passed */
    tri.internal1 = tex1Dfetch(tri_tex, 3 * id + 1);
    tri.internal2 = tex1Dfetch(tri_tex, 3 * id + 2);
    const float hu = p.pos.y + t*p.dir.y - tri.vert_ku();
    const float hv = p.pos.z + t*p.dir.z - tri.vert_kv();
    const float beta = hv*tri.b_nu() + hu*tri.b_nv();
    const float gamma = hu*tri.c_nu() + hv*tri.c_nv();
    if (!((beta >= 0.f) & (gamma >= 0.f) & (beta + gamma <= 1.f))) return;
    hit.t = t;
    hit.u = beta;
    hit.v = gamma;
    hit.id = tri.id();
}

/***************************************************************************//**
 * Test one shadow ray against one triangle
 *
 * Returns true iff triangle `id` is hit at a distance t in [t_near, t_far]
 * with (beta, gamma) inside the triangle. Same arithmetic as the closest-hit
 * test, but without closest-hit bookkeeping. The tests are written so that
 * NaN values (e.g. t = 0/0 for a ray lying in the triangle's plane) count as
 * a miss rather than as an occluder.
/******************************************************************************/
static __device__ inline bool
shadow_ray_tri( uint32_t id, const ray_t &ray, float t_near, float t_far)
{
    wald_tri_t tri;
    tri.internal0 = tex1Dfetch(tri_tex, 3 * id);
    wald_tri_t::perm_t p = tri.get_perm(ray);
    const float dot = (tri.n_d() - p.pos.x - tri.n_u()*p.pos.y - tri.n_v()*p.pos.z);
    const float denum = (p.dir.x + tri.n_u()*p.dir.y + tri.n_v()*p.dir.z);
    const float t = __fdividef(dot, denum);
    if (!((t >= t_near) & (t <= t_far))) return false;

    /* Fetch the remaining records only once the distance test has passed */
    tri.internal1 = tex1Dfetch(tri_tex, 3 * id + 1);
    tri.internal2 = tex1Dfetch(tri_tex, 3 * id + 2);
    const float hu = p.pos.y + t*p.dir.y - tri.vert_ku();
    const float hv = p.pos.z + t*p.dir.z - tri.vert_kv();
    const float beta = hv*tri.b_nu() + hu*tri.b_nv();
    const float gamma = hu*tri.c_nu() + hv*tri.c_nv();
    return (beta >= 0.f) & (gamma >= 0.f) & (beta + gamma <= 1.f);
}


/***************************************************************************//**
 * Trace the ray inside the kd-tree
 *
 * Closest-hit traversal: updates `hit` via intersect_ray_tri with the nearest
 * triangle along `ray`. `hit` is left untouched if the ray misses the scene
 * box or every triangle. Node 0 is the root. Inner nodes are visited
 * front-to-back, and the far child is deferred on a trace_stack_t together
 * with its [tnear, tfar] interval. Traversal stops once the stack is empty or
 * the closest hit already lies within the interval of the leaf just processed.
 * NOTE(review): pushes are not bounds-checked against the 32-entry stack;
 * this assumes the tree never needs more pending entries.
/******************************************************************************/
static __device__ inline void trace_kdtree(const ray_t &ray, hit_t &hit)
{
    /* First, intersect the bounding box of the scene */
    float t_near = 0.f, t_far = HUGE;
    if(!intersect_ray_box(cu_aabb, ray, t_near, t_far)) return;

    /* Then, intersect the kd-tree (node 0 is the root) */
    trace_stack_t stack;
    stack.reset();
    kdtree::node_t node;
    node.internal = tex1Dfetch(node_tex, 0);

    for(;;) {
        /* Process the non-leaf nodes */
        while(node.is_node()) {
            /* Low 2 bits: split axis; masked and shifted remaining bits: node
             * index of the first of the two adjacent children */
            const uint32_t bits = node.dim_offset_flag();
            const uint32_t axis = bits & 3u;
            const uint32_t off = (bits & (uint32_t) kdtree::node_t::mask_children) >> 3;
            const float split = node.split_coord();
            const float2 pos_dir = ray.get_dir_pos(axis);
            const float dir = pos_dir.y;
            /* Ray parameter at which the ray crosses the split plane */
            const float d = __fdividef(split - pos_dir.x, dir);
            /* Near child is off + 0 for a non-negative direction, off + 1 otherwise */
            const uint32_t sign = dir >= 0.f;
            uint32_t idx = off + (sign^1);
            /* Plane crossed before the interval: only the far child is visited */
            if (d < t_near) idx = off + (sign^0);
            /* Plane crossed inside the interval: defer the far child with
             * [d, t_far] and continue in the near child with [t_near, d] */
            else if(d <= t_far) {
                trace_stack_t::trace_t& trace = stack.get();
                trace.node = off + (sign^0);
                trace.tnear = d;
                trace.tfar = t_far;
                stack.idx++;
                t_far = d;
            }
            /* Otherwise (d > t_far, or NaN) only the near child is visited */
            node.internal = tex1Dfetch(node_tex, idx);
        }

        /* Intersect the inner triangles of the leaf and exit if possible */
        /* The leaf's triangle ids are contiguous in id_tex starting at get_list() */
        const uint32_t idx = node.get_list(), count = node.tri_count();
        for (uint32_t i = 0; i < count; ++i) {
            uint32_t id = tex1Dfetch(id_tex, idx + i);
            intersect_ray_tri(id, ray, hit, t_near, t_far);
        }
        /* Resume with the most recently deferred subtree unless the stack is
         * empty or the closest hit lies within [t_near, t_far] of this leaf.
         * The bitwise & always evaluates pop(); that is harmless because the
         * loop breaks otherwise. */
        if(stack.pop() & (hit.t > t_far)) {
            trace_stack_t::trace_t& trace = stack.get();
            node.internal = tex1Dfetch(node_tex, trace.node);
            t_near = trace.tnear;
            t_far = trace.tfar;
        } else break;
    }
}

/***************************************************************************//**
 * Trace a shadow ray inside the kd-tree
 *
 * Any-hit query: returns 1 as soon as one triangle is hit by `ray` at a
 * distance t with t_min <= t <= t_max, and 0 if no such triangle exists. The
 * traversal uses the same node encoding and near/far ordering as
 * trace_kdtree, but needs no closest-hit bookkeeping.
 *
 * Candidate triangles are tested against the full [t_min, t_max] range
 * rather than the current leaf's sub-interval. Any intersection in that range
 * is a valid occluder, and the wider range keeps a triangle that straddles
 * several leaves from being missed because its hit point rounds just outside
 * each leaf's interval.
/******************************************************************************/
static __device__ inline uint32_t shadow_kdtree(const ray_t &ray, float t_min, float t_max)
{
    /* Clip the query range against the scene bounding box */
    float t_near = t_min, t_far = t_max;
    if (!intersect_ray_box(cu_aabb, ray, t_near, t_far)) return 0;

    trace_stack_t stack;
    stack.reset();
    kdtree::node_t node;
    node.internal = tex1Dfetch(node_tex, 0);

    for (;;) {
        /* Descend to a leaf, deferring far children the interval reaches */
        while (node.is_node()) {
            const uint32_t bits = node.dim_offset_flag();
            const uint32_t axis = bits & 3u;
            const uint32_t off = (bits & (uint32_t) kdtree::node_t::mask_children) >> 3;
            const float2 pos_dir = ray.get_dir_pos(axis);
            const float dir = pos_dir.y;
            const float d = __fdividef(node.split_coord() - pos_dir.x, dir);
            const uint32_t sign = dir >= 0.f;
            const uint32_t near_child = off + (sign ^ 1);
            const uint32_t far_child = off + sign;
            uint32_t next = near_child;
            if (d < t_near) {
                next = far_child;
            } else if (d <= t_far) {
                trace_stack_t::trace_t &trace = stack.get();
                trace.node = far_child;
                trace.tnear = d;
                trace.tfar = t_far;
                stack.idx++;
                t_far = d;
            }
            node.internal = tex1Dfetch(node_tex, next);
        }

        /* Any occluder in this leaf ends the query */
        const uint32_t first = node.get_list(), count = node.tri_count();
        for (uint32_t i = 0; i < count; ++i) {
            const uint32_t id = tex1Dfetch(id_tex, first + i);
            if (shadow_ray_tri(id, ray, t_min, t_max)) return 1;
        }

        /* Resume with the most recently deferred subtree, if any */
        if (!stack.pop()) return 0;
        const trace_stack_t::trace_t &trace = stack.get();
        node.internal = tex1Dfetch(node_tex, trace.node);
        t_near = trace.tnear;
        t_far = trace.tfar;
    }
}

/***************************************************************************//**
 * Perform the ray tracing and cast a shadow ray
 *
 * One thread per pixel on a 2D grid. Pixel (x, y) is written to
 * g_odata[y * w + x] as a packed color (see frgb_to_int). Rays that miss the
 * scene produce 0. On a hit, the diffuse term
 * 1500 * saturate(n . L) / dist^2 is computed against the light at cu_lpos
 * and scaled by 0.3 when shadow_kdtree reports an occluder between the light
 * and the surface point. Threads outside the w x h image write nothing.
/******************************************************************************/
__global__ void do_shadowed(int* g_odata, int w, int h)
{
    const int x = blockIdx.x*blockDim.x + threadIdx.x;
    const int y = blockIdx.y*blockDim.y + threadIdx.y;
    /* Guard the grid tail so any launch configuration is safe */
    if (x >= w || y >= h) return;
    const int pixel = y * w + x;

    /* Set the current ray */
    const float cx = float(x), cy = float(y);
    vec_t scan_line(cu_cam.sampler.top + cu_cam.sampler.dx*cx + cu_cam.sampler.dy*cy);
    ray_t ray;
    ray.dir = scan_line.normalize();
    ray.set_pos(cu_cam.get_eye());

    /* Cast the ray. The explicit __device__ constructor starts with "no hit". */
    hit_t hit(HUGE);
    trace_kdtree(ray, hit);
    if(hit.id == ~0u) {
        g_odata[pixel] = 0u;
        return;
    }

    /* Per-triangle normal: three floats per triangle id */
    vec_t n;
    n.x = tex1Dfetch(n_tex, 3 * hit.id);
    n.y = tex1Dfetch(n_tex, 3 * hit.id + 1);
    n.z = tex1Dfetch(n_tex, 3 * hit.id + 2);

    /* Light -> surface vector. The surface point is pushed 1e-3 along the
     * normal, presumably to keep the hit triangle from occluding itself. */
    const vec_t lpos(cu_lpos.x, cu_lpos.y, cu_lpos.z);
    const vec_t dir(ray.pos + ray.dir * hit.t + n * 1e-3f - lpos);
    const float norm = dir.norm();
    const vec_t dirn = dir * __fdividef(1.f, norm);

    /* Cast a shadow ray from the light towards the surface point */
    ray_t sray;
    sray.set_pos(lpos);
    sray.set_dir(dirn);
    const bool shadowed = shadow_kdtree(sray, 0.f, norm);

    /* Diffuse shading with inverse-square falloff, dimmed when shadowed */
    const vec_t r(1500.f * __fdividef(saturate(-n.dot(dirn)), norm * norm));
    const vec_t color = r * (shadowed ? 0.3f : 1.0f);
    g_odata[pixel] = frgb_to_int(color);
}

/***************************************************************************//**
 * Perform the ray tracing and make a simple shading
 *
 * One thread per pixel on a 2D grid. Pixel (x, y) is written to
 * g_odata[y * w + x] as a packed color (see frgb_to_int). Rays that miss the
 * scene produce 0. On a hit, the diffuse term
 * 1500 * saturate(n . L) / dist^2 is computed against the light at cu_lpos,
 * without any occlusion test. Threads outside the w x h image write nothing.
/******************************************************************************/
__global__ void do_unshadowed(int* g_odata, int w, int h)
{
    const int x = blockIdx.x*blockDim.x + threadIdx.x;
    const int y = blockIdx.y*blockDim.y + threadIdx.y;
    /* Guard the grid tail so any launch configuration is safe */
    if (x >= w || y >= h) return;
    const int pixel = y * w + x;

    /* Set the current ray */
    const float cx = float(x), cy = float(y);
    vec_t scan_line(cu_cam.sampler.top + cu_cam.sampler.dx*cx + cu_cam.sampler.dy*cy);
    ray_t ray;
    ray.dir = scan_line.normalize();
    ray.set_pos(cu_cam.get_eye());

    /* Cast the ray. The explicit __device__ constructor starts with "no hit". */
    hit_t hit(HUGE);
    trace_kdtree(ray, hit);
    if(hit.id == ~0u) {
        g_odata[pixel] = 0u;
        return;
    }

    /* Per-triangle normal: three floats per triangle id */
    vec_t n;
    n.x = tex1Dfetch(n_tex, 3 * hit.id);
    n.y = tex1Dfetch(n_tex, 3 * hit.id + 1);
    n.z = tex1Dfetch(n_tex, 3 * hit.id + 2);

    /* Diffuse shading with inverse-square falloff */
    const vec_t lpos(cu_lpos.x, cu_lpos.y, cu_lpos.z);
    const vec_t dir(ray.pos + ray.dir * hit.t + n * 1e-3f - lpos);
    const float norm = dir.norm();
    const vec_t dirn = dir * __fdividef(1.f, norm);
    const vec_t r(1500.f * __fdividef(saturate(-n.dot(dirn)), norm * norm));
    g_odata[pixel] = frgb_to_int(r);
}

/***************************************************************************//**
 * Initialize CUDA
 *
 * Thin wrapper around cutil's CUT_DEVICE_INIT, which is handed the program's
 * argc/argv (presumably so the device can be chosen from the command line;
 * see cutil).
/******************************************************************************/
void init_cuda(int argc, char **argv)
{
    CUT_DEVICE_INIT(argc, argv);
}

/***************************************************************************//**
 * Perform the ray tracing with no shading at all
 *
 * One thread per pixel on a 2D grid. Pixel (x, y) is written to
 * g_odata[y * w + x]. Misses produce 0. Hits get a color proportional to
 * 1 / t in the ratio r:g:b = 0.25:0.5:1, clamped per channel by
 * frgb_to_int. Threads outside the w x h image write nothing.
/******************************************************************************/
__global__ void do_noshading(int* g_odata, int w, int h)
{
    const int x = blockIdx.x*blockDim.x + threadIdx.x;
    const int y = blockIdx.y*blockDim.y + threadIdx.y;
    /* Guard the grid tail so any launch configuration is safe */
    if (x >= w || y >= h) return;

    /* Set the current ray */
    const float cx = float(x), cy = float(y);
    vec_t scan_line(cu_cam.sampler.top + cu_cam.sampler.dx*cx + cu_cam.sampler.dy*cy);
    ray_t ray;
    ray.dir = scan_line.normalize();
    ray.set_pos(cu_cam.get_eye());

    /* Cast the ray. The explicit __device__ constructor starts with "no hit". */
    hit_t hit(HUGE);
    trace_kdtree(ray, hit);
    bool const is_hit = (hit.id != ~0u);
    /* Brightness falls off as 1 / distance */
    const float d = 255.f * 1.5f / hit.t;
    g_odata[y * w + x] = is_hit ? frgb_to_int(vec_t(d*0.25f, d*0.5f, d)) : 0u;
}

/***************************************************************************//**
 * Process with shadows
 *
 * Maps the registered PBO pbo_out and uploads three things to constant
 * memory: the camera (sizeof(camera_t) bytes from cam), the scene box
 * (sizeof(aabb_t) bytes from aabb) and the light position (3 floats from
 * pos). It then renders one w x h frame with do_shadowed and unmaps the PBO.
 * Errors are handled by cutil's CUDA_SAFE_CALL, including launch errors
 * reported by cudaGetLastError.
 *
 * The grid is w / THREAD_W by h / THREAD_H blocks (floor division), so w and
 * h are expected to be multiples of THREAD_W and THREAD_H. Otherwise the
 * trailing columns/rows are not rendered.
/******************************************************************************/
void cuda_shadowed(int pbo_out, void *cam, void *aabb, float *pos, int w, int h)
{
    int* out_data;
    CUDA_SAFE_CALL(cudaGLMapBufferObject((void **) &out_data, pbo_out));
    dim3 block(THREAD_W, THREAD_H, 1);
    dim3 grid(w / block.x, h / block.y, 1);
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_cam, cam, sizeof(camera_t)));
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_aabb, aabb, sizeof(aabb_t)));
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_lpos, pos, sizeof(float[3])));
    /* The kernel uses no dynamic shared memory, so none is requested */
    do_shadowed<<<grid, block>>>(out_data, w, h);
    /* Launch-configuration errors are only reported through cudaGetLastError */
    CUDA_SAFE_CALL(cudaGetLastError());
    CUDA_SAFE_CALL(cudaGLUnmapBufferObject(pbo_out));
}

/***************************************************************************//**
 * Process with no shadow
 *
 * Maps the registered PBO pbo_out and uploads three things to constant
 * memory: the camera (sizeof(camera_t) bytes from cam), the scene box
 * (sizeof(aabb_t) bytes from aabb) and the light position (3 floats from
 * pos). It then renders one w x h frame with do_unshadowed and unmaps the
 * PBO. Errors are handled by cutil's CUDA_SAFE_CALL, including launch errors
 * reported by cudaGetLastError.
 *
 * The grid is w / THREAD_W by h / THREAD_H blocks (floor division), so w and
 * h are expected to be multiples of THREAD_W and THREAD_H. Otherwise the
 * trailing columns/rows are not rendered.
/******************************************************************************/
void cuda_unshadowed(int pbo_out, void *cam, void *aabb, float *pos, int w, int h)
{
    int* out_data;
    CUDA_SAFE_CALL(cudaGLMapBufferObject((void **) &out_data, pbo_out));
    dim3 block(THREAD_W, THREAD_H, 1);
    dim3 grid(w / block.x, h / block.y, 1);
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_cam, cam, sizeof(camera_t)));
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_aabb, aabb, sizeof(aabb_t)));
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_lpos, pos, sizeof(float[3])));
    /* The kernel uses no dynamic shared memory, so none is requested */
    do_unshadowed<<<grid, block>>>(out_data, w, h);
    /* Launch-configuration errors are only reported through cudaGetLastError */
    CUDA_SAFE_CALL(cudaGetLastError());
    CUDA_SAFE_CALL(cudaGLUnmapBufferObject(pbo_out));
}

/***************************************************************************//**
 * Process with no shading
 *
 * Maps the registered PBO pbo_out and uploads three things to constant
 * memory: the camera (sizeof(camera_t) bytes from cam), the scene box
 * (sizeof(aabb_t) bytes from aabb) and the light position (3 floats from
 * pos; not read by this kernel). It then renders one w x h frame with
 * do_noshading and unmaps the PBO. Errors are handled by cutil's
 * CUDA_SAFE_CALL, including launch errors reported by cudaGetLastError.
 *
 * The grid is w / THREAD_W by h / THREAD_H blocks (floor division), so w and
 * h are expected to be multiples of THREAD_W and THREAD_H. Otherwise the
 * trailing columns/rows are not rendered.
/******************************************************************************/
void cuda_noshading(int pbo_out, void *cam, void *aabb, float *pos, int w, int h)
{
    int* out_data;
    CUDA_SAFE_CALL(cudaGLMapBufferObject((void **) &out_data, pbo_out));
    dim3 block(THREAD_W, THREAD_H, 1);
    dim3 grid(w / block.x, h / block.y, 1);
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_cam, cam, sizeof(camera_t)));
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_aabb, aabb, sizeof(aabb_t)));
    CUDA_SAFE_CALL(cudaMemcpyToSymbol(cu_lpos, pos, sizeof(float[3])));
    /* The kernel uses no dynamic shared memory, so none is requested */
    do_noshading<<<grid, block>>>(out_data, w, h);
    /* Launch-configuration errors are only reported through cudaGetLastError */
    CUDA_SAFE_CALL(cudaGetLastError());
    CUDA_SAFE_CALL(cudaGLUnmapBufferObject(pbo_out));
}

/***************************************************************************//**
 * Register the pixel buffer object
 *
 * Registers the OpenGL buffer object `pbo` with CUDA (legacy GL interop) so
 * the render functions can map it with cudaGLMapBufferObject. Errors are
 * handled by cutil's CUDA_SAFE_CALL.
/******************************************************************************/
void pbo_register(int pbo)
{
    CUDA_SAFE_CALL(cudaGLRegisterBufferObject(pbo));
}

/***************************************************************************//**
 * Unregister the pixel buffer object
 *
 * Undoes pbo_register for the OpenGL buffer object `pbo` (legacy GL
 * interop). Errors are handled by cutil's CUDA_SAFE_CALL.
/******************************************************************************/
void pbo_unregister(int pbo)
{
    CUDA_SAFE_CALL(cudaGLUnregisterBufferObject(pbo));
}
